// Copyright Epic Games, Inc. All Rights Reserved.

#include <zencore/callstack.h>
#include <zencore/thread.h>

#if ZEN_PLATFORM_WINDOWS
#	include <zencore/windows.h>
#	include <Dbghelp.h>
#endif

#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
#	include <execinfo.h>
#endif

#if ZEN_WITH_TESTS
#	include <zencore/testing.h>
#endif

#include <fmt/format.h>

namespace zen {
#if ZEN_PLATFORM_WINDOWS

class WinSymbolInit
{
public:
	WinSymbolInit() {}
	~WinSymbolInit()
	{
		m_CallstackLock.WithExclusiveLock([this]() {
			if (m_Initialized)
			{
				SymCleanup(m_CurrentProcess);
			}
		});
	}

	bool GetSymbol(void* Frame, SYMBOL_INFO* OutSymbolInfo, DWORD64& OutDisplacement)
	{
		bool Result = false;
		m_CallstackLock.WithExclusiveLock([&]() {
			if (!m_Initialized)
			{
				m_CurrentProcess = GetCurrentProcess();
				if (SymInitialize(m_CurrentProcess, NULL, TRUE) == TRUE)
				{
					m_Initialized = true;
				}
			}
			if (m_Initialized)
			{
				if (SymFromAddr(m_CurrentProcess, (DWORD64)Frame, &OutDisplacement, OutSymbolInfo) == TRUE)
				{
					Result = true;
				}
			}
		});
		return Result;
	}

private:
	HANDLE m_CurrentProcess = NULL;
	BOOL   m_Initialized	= FALSE;
	RwLock m_CallstackLock;
};

static WinSymbolInit WinSymbols;

#endif

// Allocates a CallstackFrames holding a copy of the FrameCount addresses in Frames.
// The header and the frame array share one malloc'd block, so FreeCallstack releases
// both with a single free().
// Returns nullptr when FrameCount is 0, when Frames is null, or when allocation fails.
CallstackFrames*
CreateCallstack(uint32_t FrameCount, void** Frames) noexcept
{
	// Nothing to copy: an empty callstack is represented by nullptr, and memcpy from a
	// null source would be undefined behavior.
	if (FrameCount == 0 || Frames == nullptr)
	{
		return nullptr;
	}

	const size_t	 FramesSize = sizeof(void*) * FrameCount;
	CallstackFrames* Callstack	= (CallstackFrames*)malloc(sizeof(CallstackFrames) + FramesSize);
	if (Callstack == nullptr)
	{
		return nullptr;
	}

	Callstack->FrameCount = FrameCount;
	// The frame array begins immediately after the header within the same allocation.
	Callstack->Frames = (void**)&Callstack[1];
	memcpy(Callstack->Frames, Frames, FramesSize);
	return Callstack;
}

// Returns a newly allocated deep copy of Callstack (release it with FreeCallstack),
// or nullptr when Callstack itself is nullptr.
CallstackFrames*
CloneCallstack(const CallstackFrames* Callstack) noexcept
{
	return (Callstack != nullptr) ? CreateCallstack(Callstack->FrameCount, Callstack->Frames) : nullptr;
}

// Releases a callstack obtained from CreateCallstack/CloneCallstack.
// The header and frame array were allocated as one malloc'd block, and free(nullptr)
// is a no-op, so passing nullptr is safe.
void
FreeCallstack(CallstackFrames* Callstack) noexcept
{
	free(Callstack);
}

// Captures up to FramesToCapture return addresses of the current thread into OutAddresses,
// after discarding the innermost FramesToSkip frames (frame 0 is presumably GetCallstack itself).
// OutAddresses must have room for at least FramesToCapture entries.
// Returns the number of addresses written. Returns 0 for a non-positive FramesToCapture, a null
// OutAddresses, or a platform without stack capture support. A negative FramesToSkip is treated as 0.
uint32_t
GetCallstack(int FramesToSkip, int FramesToCapture, void* OutAddresses[])
{
	if (FramesToCapture <= 0 || OutAddresses == nullptr)
	{
		return 0;
	}
	if (FramesToSkip < 0)
	{
		FramesToSkip = 0;
	}

#if ZEN_PLATFORM_WINDOWS
	return (uint32_t)CaptureStackBackTrace((ULONG)FramesToSkip, (ULONG)FramesToCapture, OutAddresses, nullptr);
#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
	// backtrace() cannot skip frames, so capture the skipped frames too and drop them afterwards.
	// A fixed stack buffer covers typical requests without allocating (variable-length arrays are
	// not standard C++); larger requests fall back to a heap buffer.
	constexpr int	   InlineFrameCapacity = 128;
	void*			   InlineFrames[InlineFrameCapacity];
	std::vector<void*> HeapFrames;
	const int		   TotalFrames = FramesToSkip + FramesToCapture;
	void**			   Frames	   = InlineFrames;
	if (TotalFrames > InlineFrameCapacity)
	{
		HeapFrames.resize((size_t)TotalFrames);
		Frames = HeapFrames.data();
	}

	const int FrameCount = backtrace(Frames, TotalFrames);
	if (FrameCount <= FramesToSkip)
	{
		return 0;
	}
	const int CapturedCount = FrameCount - FramesToSkip;
	memcpy(OutAddresses, Frames + FramesToSkip, sizeof(void*) * (size_t)CapturedCount);
	return (uint32_t)CapturedCount;
#else
	// No stack capture implementation for this platform.
	return 0;
#endif
}

// Produces one human-readable description per address in Frames[0..FrameCount).
// On Windows each entry is "Symbol+0xdisp [0xaddr]". On Linux/Mac it is the backtrace_symbols() text.
// If a frame cannot be symbolized, its entry falls back to "[0xaddr]" so the address is not lost.
// Returns an empty vector when FrameCount is 0 or Frames is null.
std::vector<std::string>
GetFrameSymbols(uint32_t FrameCount, void** Frames)
{
	std::vector<std::string> FrameSymbols;
	if (FrameCount == 0 || Frames == nullptr)
	{
		return FrameSymbols;
	}
	FrameSymbols.resize(FrameCount);

#if ZEN_PLATFORM_WINDOWS
	// SymFromAddr writes the symbol name past the end of SYMBOL_INFO (up to MaxNameLen chars),
	// so reserve extra space. alignas keeps the struct's 64-bit members properly aligned
	// inside the raw char buffer.
	alignas(SYMBOL_INFO) char SymbolBuffer[sizeof(SYMBOL_INFO) + 1024];
	SYMBOL_INFO*			  SymbolInfo = (SYMBOL_INFO*)SymbolBuffer;
	SymbolInfo->SizeOfStruct			 = sizeof(SYMBOL_INFO);
	SymbolInfo->MaxNameLen				 = 1023;
	for (uint32_t FrameIndex = 0; FrameIndex < FrameCount; FrameIndex++)
	{
		const uintptr_t Address		 = (uintptr_t)Frames[FrameIndex];
		DWORD64			Displacement = 0;
		if (WinSymbols.GetSymbol(Frames[FrameIndex], SymbolInfo, Displacement))
		{
			FrameSymbols[FrameIndex] = fmt::format("{}+{:#x} [{:#x}]", SymbolInfo->Name, Displacement, Address);
		}
		else
		{
			// Keep the raw address so unresolved frames remain identifiable.
			FrameSymbols[FrameIndex] = fmt::format("[{:#x}]", Address);
		}
	}
#endif
#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
	char** Messages = backtrace_symbols(Frames, (int)FrameCount);
	if (Messages != nullptr)
	{
		for (uint32_t FrameIndex = 0; FrameIndex < FrameCount; FrameIndex++)
		{
			FrameSymbols[FrameIndex] = Messages[FrameIndex];
		}
		// backtrace_symbols returns one malloc'd block; freeing the array releases the strings too.
		free(Messages);
	}
	else
	{
		// backtrace_symbols failed (e.g. out of memory); fall back to raw addresses.
		for (uint32_t FrameIndex = 0; FrameIndex < FrameCount; FrameIndex++)
		{
			FrameSymbols[FrameIndex] = fmt::format("[{:#x}]", (uintptr_t)Frames[FrameIndex]);
		}
	}
#endif
	return FrameSymbols;
}

// Appends the symbolized frames of Callstack to SB, one frame per line.
// Newlines are written only between entries, so there is no trailing newline.
void
FormatCallstack(const CallstackFrames* Callstack, StringBuilderBase& SB)
{
	bool NeedsSeparator = false;
	for (const std::string& Symbol : GetFrameSymbols(Callstack))
	{
		if (NeedsSeparator)
		{
			SB.Append("\n");
		}
		SB.Append(Symbol);
		NeedsSeparator = true;
	}
}

std::string
CallstackToString(const CallstackFrames* Callstack)
{
	StringBuilder<2048> SB;
	FormatCallstack(Callstack, SB);
	return SB.ToString();
}

#if ZEN_WITH_TESTS

TEST_CASE("Callstack.Basic")
{
	// Capture up to four frames, skipping one (presumably GetCallstack's own frame).
	void*		   Frames[4];
	const uint32_t Captured = GetCallstack(1, 4, Frames);
	CHECK(Captured > 0);

	// Every captured frame should produce a non-empty description.
	const std::vector<std::string> Descriptions = GetFrameSymbols(Captured, Frames);
	for (size_t Index = 0; Index < Descriptions.size(); ++Index)
	{
		CHECK(!Descriptions[Index].empty());
	}
}

// Intentionally empty. NOTE(review): presumably called from elsewhere so the linker keeps this
// translation unit (and the TEST_CASE above) in test builds — confirm against the caller.
void
callstack_forcelink()
{
}

#endif

}  // namespace zen
